Cashew Disease Classification Using Convolutional Neural Networks (CNN)¶

In this project, we will implement a machine learning model using Convolutional Neural Networks (CNN) to classify diseases in cashew trees based on images of their leaves. The goal is to automate the process of detecting diseases, which can help farmers manage crop health and improve productivity.

Objectives¶

  • Preprocess and prepare a dataset of images containing cashew leaves affected by various diseases.
  • Implement a CNN model to learn patterns in the images and classify them into different categories.
  • Evaluate the model's performance using metrics such as accuracy, precision, recall, and F1-score.

Key Steps¶

  1. Data Collection: We will gather images of healthy and diseased cashew leaves. These images will serve as our dataset.
  2. Data Preprocessing: Resize, normalize, and augment the images to prepare them for training.
  3. Model Building: Implement a CNN with Keras/TensorFlow to classify the images.
  4. Training: Train the CNN model on the preprocessed dataset.
  5. Evaluation: Evaluate the model's performance using test data and metrics.
  6. Prediction: Test the model with new, unseen images to predict disease classification.

By the end of this project, we aim to develop a robust model that can assist in early disease detection for cashew trees, ultimately aiding in better crop management and yield.

Let's get started by loading the dataset and performing some initial exploration!

1. Setup and importing necessary libraries¶

1.1 Ignoring warnings shown in the notebook¶

In [22]:
import warnings
warnings.filterwarnings('ignore')

1.2 Installing gdown in the environment¶

gdown downloads public files/folders from Google Drive. It handles Google Drive's download confirmation flow, which plain curl/wget cannot do directly.

In [23]:
!pip install gdown
Requirement already satisfied: gdown in /opt/conda/lib/python3.10/site-packages (5.2.0)

1.3 Necessary Imports¶

In [24]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import cv2
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
import os
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard, ModelCheckpoint
from sklearn.metrics import classification_report,confusion_matrix,ConfusionMatrixDisplay, accuracy_score
import ipywidgets as widgets
import io
from PIL import Image
from IPython.display import display,clear_output

1.4 Setting up color palettes¶

In [25]:
colors_dark = ["#1F1F1F", "#313131", '#636363', '#AEAEAE', '#DADADA']
colors_red = ["#331313", "#582626", '#9E1717', '#D35151', '#E9B4B4']
colors_green = ['#01411C','#4B6F44','#4F7942','#74C365','#D0F0C0']

sns.palplot(colors_dark)
sns.palplot(colors_green)
sns.palplot(colors_red)
[Output: palette preview strips for the dark, green, and red palettes]

1.5 Setting up label/class names¶

In [26]:
labels = ['anthracnose','gumosis','healthy','leaf miner',"red rust"]
CLASS_NAMES = labels

2. Exploratory data analysis¶

2.1 Identifying the proportion of each class¶

In [27]:
train_path = '/kaggle/input/cashew-image-dataset/Cashew/train_set'
test_path = '/kaggle/input/cashew-image-dataset/Cashew/test_set'

category_counts = {}

for category in os.listdir(train_path):
    category_folder = os.path.join(train_path, category)
    if os.path.isdir(category_folder):
        train_count = len(os.listdir(category_folder))
        category_counts[category] = train_count

for category in os.listdir(test_path):
    category_folder = os.path.join(test_path, category)
    if os.path.isdir(category_folder):
        test_count = len(os.listdir(category_folder))
        category_counts[category] = category_counts.get(category, 0) + test_count

labels = list(category_counts.keys())  # NOTE: overwrites the curated list from 1.5; order follows os.listdir
counts = list(category_counts.values())

colors = sns.color_palette("pastel", len(labels))

plt.figure(figsize=(8, 8))
plt.pie(counts, labels=labels, autopct='%1.1f%%', startangle=90, colors=colors)
plt.title("Image Distribution Across Categories", fontsize=16)
plt.show()
[Output: pie chart of image distribution across categories]

Result: the dataset is fairly balanced across classes, with one class standing out as noticeably under-represented.
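As a sanity check on the chart, the loading logs later in the notebook report ten folder counts (five train, five test). Summing them, as in the sketch below, reproduces the 25,811 images seen in X_train.shape. Note the class-to-count mapping depends on os.listdir order, so only the totals are asserted here.

```python
# Folder-by-folder image counts as reported while loading (order follows
# os.listdir, so which count belongs to which class is not fixed here).
train_folder_counts = [4751, 3102, 5877, 1714, 3466]
test_folder_counts = [1815, 1838, 1336, 425, 1487]

n_train = sum(train_folder_counts)   # total training images
n_test = sum(test_folder_counts)     # total test images
n_total = n_train + n_test

print(n_train, n_test, n_total)  # 18910 6901 25811
```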

2.2 Identifying the number of images in each class¶

In [28]:
plt.figure(figsize=(12, 6))
sns.barplot(x=counts, y=labels, palette="viridis", orient="h")

for i, count in enumerate(counts):
    plt.text(count + 0.5, i, str(count), va='center', fontsize=10, fontweight='bold', color='black')


plt.title("Number of Images Per Category", fontsize=18, fontweight='bold')
plt.xlabel("Number of Images", fontsize=14)
plt.ylabel("Categories", fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.grid(axis='x', linestyle='--', alpha=0.7)
plt.tight_layout()

# Show the chart
plt.show()
[Output: horizontal bar chart of image counts per category]

2.3 Identifying the Image Dimension Distribution¶

In [29]:
image_dimensions = []

for category in os.listdir(train_path):
    category_folder = os.path.join(train_path, category)
    if os.path.isdir(category_folder):
        for img in os.listdir(category_folder):
            img_path = os.path.join(category_folder, img)
            img_dim = plt.imread(img_path).shape
            image_dimensions.append(img_dim[:2])  # (height, width)

# Convert to a DataFrame for analysis (pandas is already imported as pd)
df_dims = pd.DataFrame(image_dimensions, columns=["Height", "Width"])
sns.boxplot(data=df_dims)
plt.title("Image Dimension Distribution")
plt.show()
[Output: boxplot of the image height/width distribution]

2.4 Visualizing samples from each class¶

In [30]:
X_train = []
y_train = []
image_size = 224
for i in labels:
    folderPath = os.path.join('/kaggle/input/cashew-image-dataset/Cashew','train_set',i)
    for j in tqdm(os.listdir(folderPath)):
        img = cv2.imread(os.path.join(folderPath,j))
        img = cv2.resize(img,(image_size, image_size))
        X_train.append(img)
        y_train.append(i)
        
for i in labels:
    folderPath = os.path.join('/kaggle/input/cashew-image-dataset/Cashew','test_set',i)
    for j in tqdm(os.listdir(folderPath)):
        img = cv2.imread(os.path.join(folderPath,j))
        img = cv2.resize(img,(image_size,image_size))
        X_train.append(img)
        y_train.append(i)
        
X_train = np.array(X_train)
y_train = np.array(y_train)
100%|██████████| 4751/4751 [00:12<00:00, 367.07it/s]
100%|██████████| 3102/3102 [00:08<00:00, 384.12it/s]
100%|██████████| 5877/5877 [00:15<00:00, 374.38it/s]
100%|██████████| 1714/1714 [00:04<00:00, 349.30it/s]
100%|██████████| 3466/3466 [00:09<00:00, 361.17it/s]
100%|██████████| 1815/1815 [00:06<00:00, 261.56it/s]
100%|██████████| 1838/1838 [00:06<00:00, 275.41it/s]
100%|██████████| 1336/1336 [00:04<00:00, 294.36it/s]
100%|██████████| 425/425 [00:01<00:00, 277.75it/s]
100%|██████████| 1487/1487 [00:05<00:00, 292.17it/s]
In [31]:
k=0
fig, ax = plt.subplots(1,5,figsize=(20,20))
fig.text(s='Sample Image From Each Class',size=18,fontweight='bold',
             fontname='monospace',color=colors_dark[1],y=0.62,x=0.4,alpha=0.8)
for i in labels:
    j=0
    while True :
        if y_train[j]==i:
            # cv2 loads images as BGR; convert to RGB so matplotlib shows true colors
            ax[k].imshow(cv2.cvtColor(X_train[j], cv2.COLOR_BGR2RGB))
            ax[k].set_title(y_train[j])
            ax[k].axis('off')
            k+=1
            break
        j+=1
[Output: one sample image from each class]

3. Getting the training and testing data ready¶

3.1 Shuffling the data for randomness¶

In [32]:
X_train, y_train = shuffle(X_train,y_train, random_state=101)
In [33]:
X_train.shape # Shape of the X_train
# (Number of images , image size, image size, number of color channels)
Out[33]:
(25811, 224, 224, 3)

3.2 Train and test split¶

Test size: 20%¶

In [34]:
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train, test_size=0.2,random_state=101)

3.3 One Hot Encoding of the labels¶

In [35]:
y_train_new = []
for i in y_train:
    y_train_new.append(labels.index(i))
y_train = y_train_new
y_train = tf.keras.utils.to_categorical(y_train)


y_test_new = []
for i in y_test:
    y_test_new.append(labels.index(i))
y_test = y_test_new
y_test = tf.keras.utils.to_categorical(y_test)
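to_categorical maps each integer class index to a one-hot row vector. A minimal NumPy equivalent of the two-step encoding above (illustrative only, not the notebook's actual code path; `label_names` and `one_hot` are names invented for this sketch):

```python
import numpy as np

label_names = ['anthracnose', 'gumosis', 'healthy', 'leaf miner', 'red rust']

def one_hot(class_names, names):
    """Integer-encode class names, then one-hot encode them
    (mirrors names.index(...) followed by tf.keras.utils.to_categorical)."""
    indices = [names.index(name) for name in class_names]
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(len(names), dtype=np.float32)[indices]

encoded = one_hot(['healthy', 'red rust'], label_names)
print(encoded)
# [[0. 0. 1. 0. 0.]
#  [0. 0. 0. 0. 1.]]
```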

4. Setting up the neural net and training¶

4.1 Transfer Learning¶

Deep convolutional neural network models may take days or even weeks to train on very large datasets.

A way to short-cut this process is to re-use the model weights from pre-trained models that were developed for standard computer vision benchmark datasets, such as the ImageNet image recognition tasks. Top performing models can be downloaded and used directly, or integrated into a new model for your own computer vision problems.

In this project, we'll use the EfficientNetB0 model initialized with weights pre-trained on the ImageNet dataset.

The include_top parameter is set to False so that the network excludes the pre-built model's output layer, letting us attach our own output layer suited to our use case!

In [36]:
effnet = EfficientNetB0(weights='imagenet',include_top=False,input_shape=(image_size,image_size,3))

4.2 Layers¶

GlobalAveragePooling2D -> This layer is similar to a max-pooling layer, except that it takes the average of each feature map rather than the maximum. Collapsing every spatial map to a single value greatly reduces the number of values passed to the classifier head, which eases the computational load during training.
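For intuition, a tiny NumPy sketch of what global average pooling does to one sample's (height, width, channels) feature maps: each channel collapses to its spatial mean, so e.g. EfficientNetB0's (7, 7, 1280) output becomes a 1280-vector.

```python
import numpy as np

# One sample: a 4x4 spatial grid with 3 feature channels
feature_maps = np.arange(48, dtype=np.float32).reshape(4, 4, 3)

# Global average pooling = mean over the two spatial axes only
pooled = feature_maps.mean(axis=(0, 1))
print(pooled.shape)  # (3,) -- one value per channel
```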

Dropout -> This layer randomly omits some neurons at each training step, making the remaining neurons less dependent on their neighbouring neurons and helping to avoid overfitting. The rate parameter is the probability that a neuron's activation is set to 0, i.e. dropped out.

Dense -> This is the output layer, which classifies the image into one of the 5 possible classes. It uses the softmax activation, a multi-class generalization of the sigmoid function.
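Concretely, softmax exponentiates each logit and normalizes so the five outputs sum to 1; with two classes it reduces to the sigmoid of the logit difference, which is the sense in which it generalizes it. A small self-contained sketch:

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; the result is unchanged
    e = np.exp(z - np.max(z))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0, 0.0])  # one raw score per class
probs = softmax(logits)
print(probs.sum())  # sums to 1.0 (within float precision)

# Two-class softmax equals the sigmoid of the logit difference
x = 1.3
two_class = softmax(np.array([x, 0.0]))[0]
sigmoid = 1.0 / (1.0 + np.exp(-x))
```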

In [37]:
model = effnet.output
model = tf.keras.layers.GlobalAveragePooling2D()(model)
model = tf.keras.layers.Dropout(rate=0.5)(model)
model = tf.keras.layers.Dense(5,activation='softmax')(model)
model = tf.keras.models.Model(inputs=effnet.input, outputs = model)

4.3 Model summary¶

In [38]:
model.summary()
Model: "functional_3"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer_1       │ (None, 224, 224,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling_2         │ (None, 224, 224,  │          0 │ input_layer_1[0]… │
│ (Rescaling)         │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ normalization_1     │ (None, 224, 224,  │          7 │ rescaling_2[0][0] │
│ (Normalization)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ rescaling_3         │ (None, 224, 224,  │          0 │ normalization_1[… │
│ (Rescaling)         │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv_pad       │ (None, 225, 225,  │          0 │ rescaling_3[0][0] │
│ (ZeroPadding2D)     │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv2D)  │ (None, 112, 112,  │        864 │ stem_conv_pad[0]… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_bn             │ (None, 112, 112,  │        128 │ stem_conv[0][0]   │
│ (BatchNormalizatio… │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_activation     │ (None, 112, 112,  │          0 │ stem_bn[0][0]     │
│ (Activation)        │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_dwconv      │ (None, 112, 112,  │        288 │ stem_activation[… │
│ (DepthwiseConv2D)   │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_bn          │ (None, 112, 112,  │        128 │ block1a_dwconv[0… │
│ (BatchNormalizatio… │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_activation  │ (None, 112, 112,  │          0 │ block1a_bn[0][0]  │
│ (Activation)        │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_squeeze  │ (None, 32)        │          0 │ block1a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_reshape  │ (None, 1, 1, 32)  │          0 │ block1a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_reduce   │ (None, 1, 1, 8)   │        264 │ block1a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_expand   │ (None, 1, 1, 32)  │        288 │ block1a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_se_excite   │ (None, 112, 112,  │          0 │ block1a_activati… │
│ (Multiply)          │ 32)               │            │ block1a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_project_co… │ (None, 112, 112,  │        512 │ block1a_se_excit… │
│ (Conv2D)            │ 16)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block1a_project_bn  │ (None, 112, 112,  │         64 │ block1a_project_… │
│ (BatchNormalizatio… │ 16)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_conv │ (None, 112, 112,  │      1,536 │ block1a_project_… │
│ (Conv2D)            │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_bn   │ (None, 112, 112,  │        384 │ block2a_expand_c… │
│ (BatchNormalizatio… │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_expand_act… │ (None, 112, 112,  │          0 │ block2a_expand_b… │
│ (Activation)        │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_dwconv_pad  │ (None, 113, 113,  │          0 │ block2a_expand_a… │
│ (ZeroPadding2D)     │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_dwconv      │ (None, 56, 56,    │        864 │ block2a_dwconv_p… │
│ (DepthwiseConv2D)   │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_bn          │ (None, 56, 56,    │        384 │ block2a_dwconv[0… │
│ (BatchNormalizatio… │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_activation  │ (None, 56, 56,    │          0 │ block2a_bn[0][0]  │
│ (Activation)        │ 96)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_squeeze  │ (None, 96)        │          0 │ block2a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_reshape  │ (None, 1, 1, 96)  │          0 │ block2a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_reduce   │ (None, 1, 1, 4)   │        388 │ block2a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_expand   │ (None, 1, 1, 96)  │        480 │ block2a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_se_excite   │ (None, 56, 56,    │          0 │ block2a_activati… │
│ (Multiply)          │ 96)               │            │ block2a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_project_co… │ (None, 56, 56,    │      2,304 │ block2a_se_excit… │
│ (Conv2D)            │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2a_project_bn  │ (None, 56, 56,    │         96 │ block2a_project_… │
│ (BatchNormalizatio… │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_conv │ (None, 56, 56,    │      3,456 │ block2a_project_… │
│ (Conv2D)            │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_bn   │ (None, 56, 56,    │        576 │ block2b_expand_c… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_expand_act… │ (None, 56, 56,    │          0 │ block2b_expand_b… │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_dwconv      │ (None, 56, 56,    │      1,296 │ block2b_expand_a… │
│ (DepthwiseConv2D)   │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_bn          │ (None, 56, 56,    │        576 │ block2b_dwconv[0… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_activation  │ (None, 56, 56,    │          0 │ block2b_bn[0][0]  │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_squeeze  │ (None, 144)       │          0 │ block2b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_reshape  │ (None, 1, 1, 144) │          0 │ block2b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_reduce   │ (None, 1, 1, 6)   │        870 │ block2b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_expand   │ (None, 1, 1, 144) │      1,008 │ block2b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_se_excite   │ (None, 56, 56,    │          0 │ block2b_activati… │
│ (Multiply)          │ 144)              │            │ block2b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_project_co… │ (None, 56, 56,    │      3,456 │ block2b_se_excit… │
│ (Conv2D)            │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_project_bn  │ (None, 56, 56,    │         96 │ block2b_project_… │
│ (BatchNormalizatio… │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_drop        │ (None, 56, 56,    │          0 │ block2b_project_… │
│ (Dropout)           │ 24)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block2b_add (Add)   │ (None, 56, 56,    │          0 │ block2b_drop[0][… │
│                     │ 24)               │            │ block2a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_conv │ (None, 56, 56,    │      3,456 │ block2b_add[0][0] │
│ (Conv2D)            │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_bn   │ (None, 56, 56,    │        576 │ block3a_expand_c… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_expand_act… │ (None, 56, 56,    │          0 │ block3a_expand_b… │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_dwconv_pad  │ (None, 59, 59,    │          0 │ block3a_expand_a… │
│ (ZeroPadding2D)     │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_dwconv      │ (None, 28, 28,    │      3,600 │ block3a_dwconv_p… │
│ (DepthwiseConv2D)   │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_bn          │ (None, 28, 28,    │        576 │ block3a_dwconv[0… │
│ (BatchNormalizatio… │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_activation  │ (None, 28, 28,    │          0 │ block3a_bn[0][0]  │
│ (Activation)        │ 144)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_squeeze  │ (None, 144)       │          0 │ block3a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_reshape  │ (None, 1, 1, 144) │          0 │ block3a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_reduce   │ (None, 1, 1, 6)   │        870 │ block3a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_expand   │ (None, 1, 1, 144) │      1,008 │ block3a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_se_excite   │ (None, 28, 28,    │          0 │ block3a_activati… │
│ (Multiply)          │ 144)              │            │ block3a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_project_co… │ (None, 28, 28,    │      5,760 │ block3a_se_excit… │
│ (Conv2D)            │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3a_project_bn  │ (None, 28, 28,    │        160 │ block3a_project_… │
│ (BatchNormalizatio… │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_conv │ (None, 28, 28,    │      9,600 │ block3a_project_… │
│ (Conv2D)            │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_bn   │ (None, 28, 28,    │        960 │ block3b_expand_c… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_expand_act… │ (None, 28, 28,    │          0 │ block3b_expand_b… │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_dwconv      │ (None, 28, 28,    │      6,000 │ block3b_expand_a… │
│ (DepthwiseConv2D)   │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_bn          │ (None, 28, 28,    │        960 │ block3b_dwconv[0… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_activation  │ (None, 28, 28,    │          0 │ block3b_bn[0][0]  │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_squeeze  │ (None, 240)       │          0 │ block3b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_reshape  │ (None, 1, 1, 240) │          0 │ block3b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_reduce   │ (None, 1, 1, 10)  │      2,410 │ block3b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_expand   │ (None, 1, 1, 240) │      2,640 │ block3b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_se_excite   │ (None, 28, 28,    │          0 │ block3b_activati… │
│ (Multiply)          │ 240)              │            │ block3b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_project_co… │ (None, 28, 28,    │      9,600 │ block3b_se_excit… │
│ (Conv2D)            │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_project_bn  │ (None, 28, 28,    │        160 │ block3b_project_… │
│ (BatchNormalizatio… │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_drop        │ (None, 28, 28,    │          0 │ block3b_project_… │
│ (Dropout)           │ 40)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block3b_add (Add)   │ (None, 28, 28,    │          0 │ block3b_drop[0][… │
│                     │ 40)               │            │ block3a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_conv │ (None, 28, 28,    │      9,600 │ block3b_add[0][0] │
│ (Conv2D)            │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_bn   │ (None, 28, 28,    │        960 │ block4a_expand_c… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_expand_act… │ (None, 28, 28,    │          0 │ block4a_expand_b… │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_dwconv_pad  │ (None, 29, 29,    │          0 │ block4a_expand_a… │
│ (ZeroPadding2D)     │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_dwconv      │ (None, 14, 14,    │      2,160 │ block4a_dwconv_p… │
│ (DepthwiseConv2D)   │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_bn          │ (None, 14, 14,    │        960 │ block4a_dwconv[0… │
│ (BatchNormalizatio… │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_activation  │ (None, 14, 14,    │          0 │ block4a_bn[0][0]  │
│ (Activation)        │ 240)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_squeeze  │ (None, 240)       │          0 │ block4a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_reshape  │ (None, 1, 1, 240) │          0 │ block4a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_reduce   │ (None, 1, 1, 10)  │      2,410 │ block4a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_expand   │ (None, 1, 1, 240) │      2,640 │ block4a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_se_excite   │ (None, 14, 14,    │          0 │ block4a_activati… │
│ (Multiply)          │ 240)              │            │ block4a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_project_co… │ (None, 14, 14,    │     19,200 │ block4a_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4a_project_bn  │ (None, 14, 14,    │        320 │ block4a_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_conv │ (None, 14, 14,    │     38,400 │ block4a_project_… │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_bn   │ (None, 14, 14,    │      1,920 │ block4b_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_expand_act… │ (None, 14, 14,    │          0 │ block4b_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_dwconv      │ (None, 14, 14,    │      4,320 │ block4b_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_bn          │ (None, 14, 14,    │      1,920 │ block4b_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_activation  │ (None, 14, 14,    │          0 │ block4b_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_squeeze  │ (None, 480)       │          0 │ block4b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_reshape  │ (None, 1, 1, 480) │          0 │ block4b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block4b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_expand   │ (None, 1, 1, 480) │     10,080 │ block4b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_se_excite   │ (None, 14, 14,    │          0 │ block4b_activati… │
│ (Multiply)          │ 480)              │            │ block4b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_project_co… │ (None, 14, 14,    │     38,400 │ block4b_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_project_bn  │ (None, 14, 14,    │        320 │ block4b_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_drop        │ (None, 14, 14,    │          0 │ block4b_project_… │
│ (Dropout)           │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4b_add (Add)   │ (None, 14, 14,    │          0 │ block4b_drop[0][… │
│                     │ 80)               │            │ block4a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_conv │ (None, 14, 14,    │     38,400 │ block4b_add[0][0] │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_bn   │ (None, 14, 14,    │      1,920 │ block4c_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_expand_act… │ (None, 14, 14,    │          0 │ block4c_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_dwconv      │ (None, 14, 14,    │      4,320 │ block4c_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_bn          │ (None, 14, 14,    │      1,920 │ block4c_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_activation  │ (None, 14, 14,    │          0 │ block4c_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_squeeze  │ (None, 480)       │          0 │ block4c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_reshape  │ (None, 1, 1, 480) │          0 │ block4c_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block4c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_expand   │ (None, 1, 1, 480) │     10,080 │ block4c_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_se_excite   │ (None, 14, 14,    │          0 │ block4c_activati… │
│ (Multiply)          │ 480)              │            │ block4c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_project_co… │ (None, 14, 14,    │     38,400 │ block4c_se_excit… │
│ (Conv2D)            │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_project_bn  │ (None, 14, 14,    │        320 │ block4c_project_… │
│ (BatchNormalizatio… │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_drop        │ (None, 14, 14,    │          0 │ block4c_project_… │
│ (Dropout)           │ 80)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block4c_add (Add)   │ (None, 14, 14,    │          0 │ block4c_drop[0][… │
│                     │ 80)               │            │ block4b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_conv │ (None, 14, 14,    │     38,400 │ block4c_add[0][0] │
│ (Conv2D)            │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_bn   │ (None, 14, 14,    │      1,920 │ block5a_expand_c… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_expand_act… │ (None, 14, 14,    │          0 │ block5a_expand_b… │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_dwconv      │ (None, 14, 14,    │     12,000 │ block5a_expand_a… │
│ (DepthwiseConv2D)   │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_bn          │ (None, 14, 14,    │      1,920 │ block5a_dwconv[0… │
│ (BatchNormalizatio… │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_activation  │ (None, 14, 14,    │          0 │ block5a_bn[0][0]  │
│ (Activation)        │ 480)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_squeeze  │ (None, 480)       │          0 │ block5a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_reshape  │ (None, 1, 1, 480) │          0 │ block5a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_reduce   │ (None, 1, 1, 20)  │      9,620 │ block5a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_expand   │ (None, 1, 1, 480) │     10,080 │ block5a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_se_excite   │ (None, 14, 14,    │          0 │ block5a_activati… │
│ (Multiply)          │ 480)              │            │ block5a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_project_co… │ (None, 14, 14,    │     53,760 │ block5a_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5a_project_bn  │ (None, 14, 14,    │        448 │ block5a_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_conv │ (None, 14, 14,    │     75,264 │ block5a_project_… │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_bn   │ (None, 14, 14,    │      2,688 │ block5b_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_expand_act… │ (None, 14, 14,    │          0 │ block5b_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_dwconv      │ (None, 14, 14,    │     16,800 │ block5b_expand_a… │
│ (DepthwiseConv2D)   │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_bn          │ (None, 14, 14,    │      2,688 │ block5b_dwconv[0… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_activation  │ (None, 14, 14,    │          0 │ block5b_bn[0][0]  │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_squeeze  │ (None, 672)       │          0 │ block5b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_reshape  │ (None, 1, 1, 672) │          0 │ block5b_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block5b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_expand   │ (None, 1, 1, 672) │     19,488 │ block5b_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_se_excite   │ (None, 14, 14,    │          0 │ block5b_activati… │
│ (Multiply)          │ 672)              │            │ block5b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_project_co… │ (None, 14, 14,    │     75,264 │ block5b_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_project_bn  │ (None, 14, 14,    │        448 │ block5b_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_drop        │ (None, 14, 14,    │          0 │ block5b_project_… │
│ (Dropout)           │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5b_add (Add)   │ (None, 14, 14,    │          0 │ block5b_drop[0][… │
│                     │ 112)              │            │ block5a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_conv │ (None, 14, 14,    │     75,264 │ block5b_add[0][0] │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_bn   │ (None, 14, 14,    │      2,688 │ block5c_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_expand_act… │ (None, 14, 14,    │          0 │ block5c_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_dwconv      │ (None, 14, 14,    │     16,800 │ block5c_expand_a… │
│ (DepthwiseConv2D)   │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_bn          │ (None, 14, 14,    │      2,688 │ block5c_dwconv[0… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_activation  │ (None, 14, 14,    │          0 │ block5c_bn[0][0]  │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_squeeze  │ (None, 672)       │          0 │ block5c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_reshape  │ (None, 1, 1, 672) │          0 │ block5c_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block5c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_expand   │ (None, 1, 1, 672) │     19,488 │ block5c_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_se_excite   │ (None, 14, 14,    │          0 │ block5c_activati… │
│ (Multiply)          │ 672)              │            │ block5c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_project_co… │ (None, 14, 14,    │     75,264 │ block5c_se_excit… │
│ (Conv2D)            │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_project_bn  │ (None, 14, 14,    │        448 │ block5c_project_… │
│ (BatchNormalizatio… │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_drop        │ (None, 14, 14,    │          0 │ block5c_project_… │
│ (Dropout)           │ 112)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block5c_add (Add)   │ (None, 14, 14,    │          0 │ block5c_drop[0][… │
│                     │ 112)              │            │ block5b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_conv │ (None, 14, 14,    │     75,264 │ block5c_add[0][0] │
│ (Conv2D)            │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_bn   │ (None, 14, 14,    │      2,688 │ block6a_expand_c… │
│ (BatchNormalizatio… │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_expand_act… │ (None, 14, 14,    │          0 │ block6a_expand_b… │
│ (Activation)        │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_dwconv_pad  │ (None, 17, 17,    │          0 │ block6a_expand_a… │
│ (ZeroPadding2D)     │ 672)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_dwconv      │ (None, 7, 7, 672) │     16,800 │ block6a_dwconv_p… │
│ (DepthwiseConv2D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_bn          │ (None, 7, 7, 672) │      2,688 │ block6a_dwconv[0… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_activation  │ (None, 7, 7, 672) │          0 │ block6a_bn[0][0]  │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_squeeze  │ (None, 672)       │          0 │ block6a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_reshape  │ (None, 1, 1, 672) │          0 │ block6a_se_squee… │
│ (Reshape)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_reduce   │ (None, 1, 1, 28)  │     18,844 │ block6a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_expand   │ (None, 1, 1, 672) │     19,488 │ block6a_se_reduc… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_se_excite   │ (None, 7, 7, 672) │          0 │ block6a_activati… │
│ (Multiply)          │                   │            │ block6a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_project_co… │ (None, 7, 7, 192) │    129,024 │ block6a_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6a_project_bn  │ (None, 7, 7, 192) │        768 │ block6a_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_conv │ (None, 7, 7,      │    221,184 │ block6a_project_… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_bn   │ (None, 7, 7,      │      4,608 │ block6b_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_expand_act… │ (None, 7, 7,      │          0 │ block6b_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_dwconv      │ (None, 7, 7,      │     28,800 │ block6b_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_bn          │ (None, 7, 7,      │      4,608 │ block6b_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_activation  │ (None, 7, 7,      │          0 │ block6b_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_squeeze  │ (None, 1152)      │          0 │ block6b_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_reshape  │ (None, 1, 1,      │          0 │ block6b_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6b_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_expand   │ (None, 1, 1,      │     56,448 │ block6b_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_se_excite   │ (None, 7, 7,      │          0 │ block6b_activati… │
│ (Multiply)          │ 1152)             │            │ block6b_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_project_co… │ (None, 7, 7, 192) │    221,184 │ block6b_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_project_bn  │ (None, 7, 7, 192) │        768 │ block6b_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_drop        │ (None, 7, 7, 192) │          0 │ block6b_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6b_add (Add)   │ (None, 7, 7, 192) │          0 │ block6b_drop[0][… │
│                     │                   │            │ block6a_project_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_conv │ (None, 7, 7,      │    221,184 │ block6b_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_bn   │ (None, 7, 7,      │      4,608 │ block6c_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_expand_act… │ (None, 7, 7,      │          0 │ block6c_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_dwconv      │ (None, 7, 7,      │     28,800 │ block6c_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_bn          │ (None, 7, 7,      │      4,608 │ block6c_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_activation  │ (None, 7, 7,      │          0 │ block6c_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_squeeze  │ (None, 1152)      │          0 │ block6c_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_reshape  │ (None, 1, 1,      │          0 │ block6c_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6c_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_expand   │ (None, 1, 1,      │     56,448 │ block6c_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_se_excite   │ (None, 7, 7,      │          0 │ block6c_activati… │
│ (Multiply)          │ 1152)             │            │ block6c_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_project_co… │ (None, 7, 7, 192) │    221,184 │ block6c_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_project_bn  │ (None, 7, 7, 192) │        768 │ block6c_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_drop        │ (None, 7, 7, 192) │          0 │ block6c_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6c_add (Add)   │ (None, 7, 7, 192) │          0 │ block6c_drop[0][… │
│                     │                   │            │ block6b_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_conv │ (None, 7, 7,      │    221,184 │ block6c_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_bn   │ (None, 7, 7,      │      4,608 │ block6d_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_expand_act… │ (None, 7, 7,      │          0 │ block6d_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_dwconv      │ (None, 7, 7,      │     28,800 │ block6d_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_bn          │ (None, 7, 7,      │      4,608 │ block6d_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_activation  │ (None, 7, 7,      │          0 │ block6d_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_squeeze  │ (None, 1152)      │          0 │ block6d_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_reshape  │ (None, 1, 1,      │          0 │ block6d_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block6d_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_expand   │ (None, 1, 1,      │     56,448 │ block6d_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_se_excite   │ (None, 7, 7,      │          0 │ block6d_activati… │
│ (Multiply)          │ 1152)             │            │ block6d_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_project_co… │ (None, 7, 7, 192) │    221,184 │ block6d_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_project_bn  │ (None, 7, 7, 192) │        768 │ block6d_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_drop        │ (None, 7, 7, 192) │          0 │ block6d_project_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block6d_add (Add)   │ (None, 7, 7, 192) │          0 │ block6d_drop[0][… │
│                     │                   │            │ block6c_add[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_conv │ (None, 7, 7,      │    221,184 │ block6d_add[0][0] │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_bn   │ (None, 7, 7,      │      4,608 │ block7a_expand_c… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_expand_act… │ (None, 7, 7,      │          0 │ block7a_expand_b… │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_dwconv      │ (None, 7, 7,      │     10,368 │ block7a_expand_a… │
│ (DepthwiseConv2D)   │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_bn          │ (None, 7, 7,      │      4,608 │ block7a_dwconv[0… │
│ (BatchNormalizatio… │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_activation  │ (None, 7, 7,      │          0 │ block7a_bn[0][0]  │
│ (Activation)        │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_squeeze  │ (None, 1152)      │          0 │ block7a_activati… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_reshape  │ (None, 1, 1,      │          0 │ block7a_se_squee… │
│ (Reshape)           │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_reduce   │ (None, 1, 1, 48)  │     55,344 │ block7a_se_resha… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_expand   │ (None, 1, 1,      │     56,448 │ block7a_se_reduc… │
│ (Conv2D)            │ 1152)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_se_excite   │ (None, 7, 7,      │          0 │ block7a_activati… │
│ (Multiply)          │ 1152)             │            │ block7a_se_expan… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_project_co… │ (None, 7, 7, 320) │    368,640 │ block7a_se_excit… │
│ (Conv2D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ block7a_project_bn  │ (None, 7, 7, 320) │      1,280 │ block7a_project_… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_conv (Conv2D)   │ (None, 7, 7,      │    409,600 │ block7a_project_… │
│                     │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_bn              │ (None, 7, 7,      │      5,120 │ top_conv[0][0]    │
│ (BatchNormalizatio… │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ top_activation      │ (None, 7, 7,      │          0 │ top_bn[0][0]      │
│ (Activation)        │ 1280)             │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 1280)      │          0 │ top_activation[0… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_1 (Dropout) │ (None, 1280)      │          0 │ global_average_p… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense)     │ (None, 5)         │      6,405 │ dropout_1[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 4,055,976 (15.47 MB)
 Trainable params: 4,013,953 (15.31 MB)
 Non-trainable params: 42,023 (164.16 KB)

4.4 Model compilation¶

Loss function¶

CategoricalCrossentropy() : Categorical cross-entropy (also known as softmax loss) is the standard loss function for multiclass classification with one-hot labels: a softmax activation on the output layer combined with a cross-entropy loss. Training with this loss teaches the network to output a probability distribution over the N classes for each image.

Optimizer Function¶

Adam() : The Adam optimizer (short for Adaptive Moment Estimation) is a widely used optimization algorithm in neural networks that combines the advantages of two popular optimization methods: Momentum and RMSProp. It is particularly effective for handling sparse gradients, noisy objectives, and non-stationary problems.

Metrics¶

Accuracy() : The ratio of the number of correct predictions to the total number of input samples.
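As a quick sanity check, both quantities can be computed by hand. The two-sample example below is purely hypothetical (not data from this notebook):

```python
import numpy as np

# Hypothetical example: two samples, three classes (illustrative only).
y_true = np.array([[1, 0, 0],
                   [0, 1, 0]])          # one-hot labels
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])    # softmax outputs

# Categorical cross-entropy: mean of -log(probability assigned to the true class).
cce = -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

# Accuracy: fraction of samples where argmax(prediction) == argmax(label).
acc = np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1))

print(round(cce, 4), acc)  # → 0.2899 1.0
```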

In [39]:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

4.5 Getting the callback functions ready¶

Callback Functions¶

TensorBoard()

  • TensorBoard is a visualization tool provided by TensorFlow. It allows us to monitor various metrics like loss, accuracy, learning rate, and more during the training process.
  • In this setup:
    • The logs are saved in a directory named logs, which can be used to visualize training progress in TensorBoard.

ModelCheckpoint()

  • This callback is used to save the model during training. It ensures the best version of the model (based on a monitored metric) is saved automatically.
  • In this setup:
    • The model is saved with the filename cashew-effnet.keras.
    • The metric being monitored is val_accuracy (validation accuracy).
    • The model is saved only when it achieves the highest validation accuracy so far.
    • save_best_only=True ensures only the best model is saved.
    • verbose=1 provides updates during the saving process.

ReduceLROnPlateau()

  • This callback reduces the learning rate when the monitored metric stops improving, helping the model converge better during later training stages.
  • In this setup:
    • The metric being monitored is val_accuracy.
    • If the validation accuracy does not improve for 2 epochs, the learning rate is reduced by a factor of 0.3.
    • min_delta=0.001 sets the minimum change in validation accuracy required to qualify as an improvement.
    • verbose=1 provides updates when the learning rate is reduced.
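A quick sketch of the schedule this produces: assuming Adam's default learning rate of 1e-3 as the starting point, each plateau multiplies the rate by the factor of 0.3, which matches the 3.0e-04 and 9.0e-05 values visible in the training log further down:

```python
# Sketch of ReduceLROnPlateau's effect with factor=0.3, assuming
# Adam's default starting learning rate of 1e-3.
lr = 1e-3
factor = 0.3

lr_after_first_drop = lr * factor                     # 3.0e-04
lr_after_second_drop = lr_after_first_drop * factor   # 9.0e-05

print(f"{lr_after_first_drop:.4e} {lr_after_second_drop:.4e}")
```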
In [40]:
tensorboard = TensorBoard(log_dir='logs')
checkpoint = ModelCheckpoint("cashew-effnet.keras", monitor="val_accuracy",
                             save_best_only=True, mode="auto", verbose=1)
reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.3, patience=2,
                              min_delta=0.001, mode='auto', verbose=1)

4.6 Model fitting¶

Model.fit()¶

The fit() method is used to train the model on the training data (X_train, y_train) for a specified number of epochs and batch size. It also incorporates the callbacks defined earlier to enhance and monitor the training process.

Key Parameters¶

  • X_train, y_train

    • X_train: The input training data (features).
    • y_train: The corresponding labels for the training data.
  • validation_split=0.2

    • This reserves 20% of the training data for validation to evaluate the model’s performance during training.
  • epochs=12

    • Specifies that the model will train for 12 full passes over the entire training dataset.
  • verbose=1

    • Displays detailed training progress, including loss, accuracy, and other metrics per epoch.
  • batch_size=32

    • The model processes 32 samples at a time before updating the weights, which helps balance memory usage and speed.
  • callbacks=[tensorboard, checkpoint, reduce_lr]

    • Integrates the previously defined callbacks to monitor and optimize the training process:
      • TensorBoard: Logs metrics for visualization in TensorBoard.
      • ModelCheckpoint: Saves the best model based on validation accuracy.
      • ReduceLROnPlateau: Dynamically adjusts the learning rate to improve convergence.

Together, these settings make the training process efficient, monitored, and adaptable.

In [41]:
history = model.fit(X_train, y_train, validation_split=0.2, epochs=12, verbose=1,
                    batch_size=32, callbacks=[tensorboard, checkpoint, reduce_lr])
Epoch 1/12
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1732167081.443094     778 service.cc:145] XLA service 0x7e07800a1590 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1732167081.443153     778 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1732167081.443157     778 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1732167138.351220     778 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 126ms/step - accuracy: 0.8518 - loss: 0.4263
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1732167254.390334     778 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'input_multiply_reduce_fusion', 160 bytes spill stores, 160 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion', 24 bytes spill stores, 24 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion_2', 24 bytes spill stores, 24 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion_3', 16 bytes spill stores, 16 bytes spill loads

517/517 ━━━━━━━━━━━━━━━━━━━━ 0s 223ms/step - accuracy: 0.8519 - loss: 0.4260
Epoch 1: val_accuracy improved from -inf to 0.93923, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 221s 252ms/step - accuracy: 0.8520 - loss: 0.4258 - val_accuracy: 0.9392 - val_loss: 0.2125 - learning_rate: 0.0010
Epoch 2/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 128ms/step - accuracy: 0.9540 - loss: 0.1572
Epoch 2: val_accuracy improved from 0.93923 to 0.95182, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 71s 137ms/step - accuracy: 0.9540 - loss: 0.1572 - val_accuracy: 0.9518 - val_loss: 0.1774 - learning_rate: 0.0010
Epoch 3/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 132ms/step - accuracy: 0.9620 - loss: 0.1243
Epoch 3: val_accuracy improved from 0.95182 to 0.96295, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 73s 142ms/step - accuracy: 0.9619 - loss: 0.1243 - val_accuracy: 0.9630 - val_loss: 0.1312 - learning_rate: 0.0010
Epoch 4/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 134ms/step - accuracy: 0.9693 - loss: 0.0925
Epoch 4: val_accuracy did not improve from 0.96295
517/517 ━━━━━━━━━━━━━━━━━━━━ 74s 143ms/step - accuracy: 0.9693 - loss: 0.0925 - val_accuracy: 0.9465 - val_loss: 0.1622 - learning_rate: 0.0010
Epoch 5/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step - accuracy: 0.9676 - loss: 0.0970
Epoch 5: val_accuracy improved from 0.96295 to 0.97361, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 146ms/step - accuracy: 0.9676 - loss: 0.0970 - val_accuracy: 0.9736 - val_loss: 0.0861 - learning_rate: 0.0010
Epoch 6/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9775 - loss: 0.0670
Epoch 6: val_accuracy did not improve from 0.97361
517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9775 - loss: 0.0671 - val_accuracy: 0.9487 - val_loss: 0.1748 - learning_rate: 0.0010
Epoch 7/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9720 - loss: 0.0900
Epoch 7: val_accuracy did not improve from 0.97361

Epoch 7: ReduceLROnPlateau reducing learning rate to 0.0003000000142492354.
517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9720 - loss: 0.0900 - val_accuracy: 0.9586 - val_loss: 0.1304 - learning_rate: 0.0010
Epoch 8/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9886 - loss: 0.0352
Epoch 8: val_accuracy improved from 0.97361 to 0.97893, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9887 - loss: 0.0352 - val_accuracy: 0.9789 - val_loss: 0.0635 - learning_rate: 3.0000e-04
Epoch 9/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9965 - loss: 0.0137
Epoch 9: val_accuracy improved from 0.97893 to 0.98160, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9965 - loss: 0.0137 - val_accuracy: 0.9816 - val_loss: 0.0708 - learning_rate: 3.0000e-04
Epoch 10/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9976 - loss: 0.0069
Epoch 10: val_accuracy improved from 0.98160 to 0.98184, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9976 - loss: 0.0069 - val_accuracy: 0.9818 - val_loss: 0.0775 - learning_rate: 3.0000e-04
Epoch 11/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9985 - loss: 0.0047
Epoch 11: val_accuracy did not improve from 0.98184

Epoch 11: ReduceLROnPlateau reducing learning rate to 9.000000427477062e-05.
517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9985 - loss: 0.0047 - val_accuracy: 0.9809 - val_loss: 0.0926 - learning_rate: 3.0000e-04
Epoch 12/12
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9988 - loss: 0.0033
Epoch 12: val_accuracy improved from 0.98184 to 0.98329, saving model to cashew-effnet.keras
517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9988 - loss: 0.0033 - val_accuracy: 0.9833 - val_loss: 0.0817 - learning_rate: 9.0000e-05

4.7 Performance Analysis¶

In [42]:
epochs = list(range(12))
fig, ax = plt.subplots(1,2,figsize=(14,7))
train_acc = history.history['accuracy']
train_loss = history.history['loss']
val_acc = history.history['val_accuracy']
val_loss = history.history['val_loss']

fig.text(s='Epochs vs. Training and Validation Accuracy/Loss',size=18,fontweight='bold',
             fontname='monospace',color=colors_dark[1],y=1,x=0.28,alpha=0.8)

sns.despine()
ax[0].plot(epochs, train_acc, marker='o',markerfacecolor=colors_green[2],color=colors_green[3],
           label = 'Training Accuracy')
ax[0].plot(epochs, val_acc, marker='o',markerfacecolor=colors_red[2],color=colors_red[3],
           label = 'Validation Accuracy')
ax[0].legend(frameon=False)
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Accuracy')

sns.despine()
ax[1].plot(epochs, train_loss, marker='o',markerfacecolor=colors_green[2],color=colors_green[3],
           label ='Training Loss')
ax[1].plot(epochs, val_loss, marker='o',markerfacecolor=colors_red[2],color=colors_red[3],
           label = 'Validation Loss')
ax[1].legend(frameon=False)
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Loss')

fig.show()
[Figure: training vs. validation accuracy and loss across epochs]

5. Performance analysis on test dataset¶

In [43]:
pred = model.predict(X_test)
pred = np.argmax(pred,axis=1)
y_test_new = np.argmax(y_test,axis=1)
162/162 ━━━━━━━━━━━━━━━━━━━━ 17s 75ms/step

5.1 Classification Report and Confusion Matrix¶

Classification Report¶

  • classification_report(y_test_new, pred)
    • The classification report provides a detailed summary of the model's performance on the test data.
    • Metrics included:
      • Precision: The proportion of true positive predictions out of all positive predictions.
      • Recall (Sensitivity): The proportion of true positive predictions out of actual positives.
      • F1-Score: The harmonic mean of precision and recall, balancing both metrics.
      • Support: The number of true instances for each class.

Confusion Matrix¶

  • confusion_matrix(y_test_new, pred)
    • The confusion matrix gives a tabular representation of the model's predictions versus the actual labels.
    • Rows represent the actual classes, and columns represent the predicted classes.
    • Components include:
      • True Positives (TP): Positive instances correctly predicted as positive.
      • True Negatives (TN): Negative instances correctly predicted as negative.
      • False Positives (FP): Negative instances incorrectly predicted as positive.
      • False Negatives (FN): Positive instances incorrectly predicted as negative.

These metrics allow us to evaluate how well the model classifies the test data, providing insights into strengths and areas for improvement.
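The per-class metrics in the report can be derived directly from the confusion matrix. A minimal sketch with a hypothetical 3-class matrix (illustrative only, not this notebook's results):

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows = actual, cols = predicted).
cm = np.array([[50,  2,  3],
               [ 4, 45,  1],
               [ 2,  3, 40]])

tp = np.diag(cm)                    # correctly classified samples per class
precision = tp / cm.sum(axis=0)     # column sums = total predicted per class
recall = tp / cm.sum(axis=1)        # row sums = total actual per class (support)
f1 = 2 * precision * recall / (precision + recall)

print(precision.round(3), recall.round(3), f1.round(3))
```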

In [47]:
cr = classification_report(y_test_new,pred)
cmetric = confusion_matrix(y_test_new,pred)
print(cr)
              precision    recall  f1-score   support

           0       0.99      0.99      0.99      1296
           1       0.96      0.95      0.96       995
           2       0.98      0.99      0.99      1453
           3       0.99      1.00      1.00       430
           4       0.97      0.96      0.96       989

    accuracy                           0.98      5163
   macro avg       0.98      0.98      0.98      5163
weighted avg       0.98      0.98      0.98      5163

5.2 Visualizing the classification report¶

In [48]:
raw_data = cr.split("\n")
precision = []
recall = []
f1_score = []
# Rows 2..-5 of the printed report are the per-class lines
# (skipping the header, blank lines, and the accuracy/avg summary rows).
for line in raw_data[2:-5]:
    line = line.split()
    precision.append(float(line[1]))
    recall.append(float(line[2]))
    f1_score.append(float(line[3]))
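Parsing the printed report is fragile if its layout ever changes; `classification_report` can instead return a nested dict via `output_dict=True`. A sketch with dummy labels (the notebook would pass `y_test_new` and `pred`):

```python
from sklearn.metrics import classification_report

# Dummy labels for illustration; the notebook would use y_test_new and pred.
y_true = [0, 0, 1, 1, 2, 2]
y_hat = [0, 0, 1, 0, 2, 2]

report = classification_report(y_true, y_hat, output_dict=True)

# Per-class metrics are keyed by the class label as a string.
classes = sorted(set(y_true))
precision = [report[str(c)]["precision"] for c in classes]
recall = [report[str(c)]["recall"] for c in classes]
f1_score = [report[str(c)]["f1-score"] for c in classes]
```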
In [55]:
x = np.arange(len(labels))
width = 0.2

colors = sns.color_palette("Set2", 3)

plt.figure(figsize=(12, 6))

plt.bar(x, precision, width, label='Precision', color=colors[0])
plt.bar(x + width, recall, width, label='Recall', color=colors[1])
plt.bar(x + 2 * width, f1_score, width, label='F1-score', color=colors[2])

for i, v in enumerate(precision):
    plt.text(i , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(recall):
    plt.text(i + width , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(f1_score):
    plt.text(i + 2 * width , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)

plt.grid(axis='y', linestyle='--', alpha=0.7)

plt.xlabel('Class', fontsize=12)
plt.ylabel('Score', fontsize=12)
plt.title('Precision, Recall, and F1-score by Class', fontsize=14, fontweight='bold')
plt.xticks(x + width, labels, rotation=15, fontsize=10)
plt.axhline(y=1, color='red', linestyle='--', linewidth=1, label='Highest')

plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1), fontsize=10)

plt.tight_layout()
plt.show()
[Figure: grouped bar chart of precision, recall, and F1-score per class]

5.3 Visualizing the Confusion Matrix¶

In [56]:
fig, ax = plt.subplots(figsize=(10, 8))

CLASS_NAMES = labels

disp = ConfusionMatrixDisplay(confusion_matrix=cmetric, display_labels=CLASS_NAMES)
disp.plot(ax=ax, cmap='binary')


ax.set_xticks(range(len(CLASS_NAMES))) 
ax.set_xticklabels(CLASS_NAMES, rotation=90) 

ax.set_yticks(range(len(CLASS_NAMES)))  
ax.set_yticklabels(CLASS_NAMES)  

for text in ax.texts:
    text.set_fontsize(12)

ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')

plt.grid(False)
plt.show()
[Figure: confusion matrix for the five cashew leaf classes]

5.4 Advanced visualizations¶

In [57]:
test_predictions = model.predict(X_test)
162/162 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step
In [58]:
y_val_indexes = np.array([val.argmax() for val in tqdm(y_test)])
y_val_indexes[:5]
100%|██████████| 5163/5163 [00:00<00:00, 990676.22it/s]
Out[58]:
array([0, 0, 4, 4, 2])
In [59]:
y_test_classes = np.argmax(test_predictions, axis=1)
y_test_classes[0:5]
Out[59]:
array([0, 0, 4, 4, 2])
In [60]:
y_test_classes.shape, y_val_indexes.shape
Out[60]:
((5163,), (5163,))
In [61]:
y_true = y_val_indexes
accuracy = accuracy_score(y_true, y_test_classes)
accuracy
Out[61]:
0.9783071857447221
In [62]:
index = 40
print(test_predictions[index])
print(f"Max value (probability of prediction): {np.max(test_predictions[index])}")
print(f"Sum: {np.sum(test_predictions[index])}")
print(f"Max index: {np.argmax(test_predictions[index])}")
print(f"Predicted label: {CLASS_NAMES[np.argmax(test_predictions[index])]}")
[3.4615238e-10 8.7412627e-10 7.2288189e-09 4.4301933e-07 9.9999952e-01]
Max value (probability of prediction): 0.9999995231628418
Sum: 1.0
Max index: 4
Predicted label: leaf miner
In [63]:
CLASS_NAMES[4]
Out[63]:
'leaf miner'
In [64]:
# Turn prediction probabilities into their respective label (easier to understand)
def get_pred_label(prediction_probabilities):
  """
  Turns an array of prediction probabilities into a label.
  """
  return CLASS_NAMES[np.argmax(prediction_probabilities)]

# Get a predicted label based on an array of prediction probabilities
pred_label = get_pred_label(test_predictions[81])
pred_label
Out[64]:
'red rust'
In [66]:
def plot_pred(prediction_probabilities, labels, images, n=1):
  """
  View the prediction, ground truth and image for sample n
  """
  pred_prob, true_label, image = prediction_probabilities[n], CLASS_NAMES[labels[n]], images[n]

  # Get the pred label
  pred_label = get_pred_label(pred_prob)

  # Plot image & remove ticks
  plt.imshow(image.astype("uint8"))
  plt.xticks([])
  plt.yticks([])
    
  # Change the colour of the title depending on if the prediction is right or wrong
  if pred_label == true_label:
    color = "green"
  else:
    color = "red"
  
  # Change plot title to be predicted, probability of prediction and truth label
  plt.title("{} {:2.0f}% {}".format(pred_label,
                                    np.max(pred_prob)*100,
                                    true_label),
                                    color=color)

plot_pred() : Visualizes the model's prediction for a given sample, showing the predicted label, its confidence percentage, and the true label.

In [67]:
plot_pred(prediction_probabilities=test_predictions,
          labels=y_val_indexes,
          images=X_test,
          n=2400)
plt.show()
[Figure: sample leaf image with predicted label, confidence, and true label]
In [68]:
def plot_pred_conf(prediction_probabilities, labels, n=1):
  """
  Plots the prediction confidence for each class along with the truth label for sample n.
  """
  pred_prob, true_label = prediction_probabilities[n], labels[n]

  # Class indexes sorted from highest to lowest confidence
  top_pred_indexes = pred_prob.argsort()[::-1]
  # Confidence values in that order
  top_pred_values = pred_prob[top_pred_indexes]

  # Setup plot
  top_plot = plt.bar(np.arange(len(top_pred_indexes)),
                     top_pred_values,
                     color="grey")
  plt.xticks(np.arange(len(top_pred_indexes)),
             labels=top_pred_indexes,
             rotation="vertical")

  # Colour the bar of the true label green
  if np.isin(true_label, top_pred_indexes):
    top_plot[np.argmax(top_pred_indexes == true_label)].set_color("green")
In [69]:
plot_pred_conf(prediction_probabilities = test_predictions,
               labels = y_val_indexes,
               n = 8)
[Figure: bar chart of per-class prediction confidences for one sample]
In [70]:
i_multiplier = 20
num_rows = 10
num_cols = 2
num_images = num_rows * num_cols
plt.figure(figsize=(10 * num_cols, 5 * num_rows))
for i in range(num_images):
  plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
  plot_pred(prediction_probabilities=test_predictions,
            labels=y_val_indexes,
            images=X_test,
            n=i+i_multiplier)
  plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
  plot_pred_conf(prediction_probabilities=test_predictions,
                 labels=y_val_indexes,
                 n=i+i_multiplier)
plt.tight_layout(h_pad=1.0)
plt.show()
No description has been provided for this image